
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
import pandas as pd
from sklearn.metrics import roc_curve, auc
from pathlib import Path
from matplotlib.pyplot import savefig  
from matplotlib.pyplot import figaspect  
import csv
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_curve

import warnings
# Ignore every DeprecationWarning; this installs a process-wide filter at import time.
warnings.simplefilter("ignore", DeprecationWarning)



def _pairwise_specs(model_type, longer):
    """Return the predictor groupings to compare, one tuple per figure set.

    Each entry is ``(labels, predictors, colors, linestyles)``: ``predictors``
    are column names in the prediction CSV and ``labels`` their legend names
    (same order).  ``colors``/``linestyles`` may be longer than ``labels``;
    only the leading entries are used.

    * ``longer == 1``: five comparisons (text/history baseline plus four
      "X alone" vs "text & X & history" variants).
    * ``model_type`` ending in ``'single_'``: one comparison of classifiers.
    * otherwise: the text/history comparison only.
    """
    linestyles = [':', '--', '-', '-.']
    baseline = (['Text', 'History', 'Text & history'],
                ['text', 'dummies', 'mine'],
                ['darkgreen', 'darkorange', 'navy', 'red'],
                linestyles)

    if longer == 1:
        swapped_colors = ['darkgreen', 'navy', 'darkorange', 'red']
        # (label of X alone, label of combination, column of X alone, column of combination)
        variants = [
            ('Standard', 'Text & standard & history', 'silverstone', 'mine_gold'),
            ('Commodities', 'Text & commodities & history', 'commodity_alone', 'commodity'),
            ('Comm. index', 'Text & comm. index & history',
             'commodity_short_alone', 'commodity_short'),
            ('Comm. & continent', 'Text & comm. & continent & history',
             'commodity_region_alone', 'commodity_region'),
        ]
        return [baseline] + [
            (['Text', 'History', alone_label, combined_label],
             ['text', 'dummies', alone_col, combined_col],
             swapped_colors,
             linestyles)
            for alone_label, combined_label, alone_col, combined_col in variants
        ]

    if model_type[-7:] == 'single_':
        return [(['Neighbor', 'Logit', 'Neural', 'AdaBoost', 'RandomForest'],
                 ['neighbor', 'Logit', 'neural', 'adaboost', 'randomforest'],
                 ['darkgreen', 'darkorange', 'navy', 'red', 'gray'],
                 [':', '-', '--', '-.', ':'])]

    return [baseline]


def _event_layout(longer):
    """Return ``(panel titles, number of outcome panels)`` for the figure grids."""
    if longer == 2:
        names = ['Non-state', 'One-sided', 'State-based']
    else:
        names = ['Any violence', 'Armed conflict']
    return names, len(names)


def _bootstrap_interval(scores, lower=0.025, upper=0.975):
    """Return the ``(lower, upper)`` empirical quantiles of *scores*.

    Uses the value at sorted position ``int(q * len(scores))`` (clamped to the
    last element).  Returns ``(nan, nan)`` when *scores* is empty.
    """
    ordered = np.sort(np.asarray(scores, dtype=float))
    if ordered.size == 0:
        return np.nan, np.nan
    last = ordered.size - 1
    return (ordered[min(int(lower * ordered.size), last)],
            ordered[min(int(upper * ordered.size), last)])


def _prepare_sample(data, dep, dp, g, longer, hard, columns):
    """Select the evaluation rows for outcome column *dep*.

    Keeps rows whose ``dep`` column equals *dep*.  When ``g > 1`` the rows are
    sorted by ``isocode, year, month`` and rows whose previous observation of
    the same ``isocode`` had ``dep == 1`` are dropped.  When *hard* is true,
    rows whose diplomacy column (*dp*, or ``'anyviolence_dp'`` when
    ``longer == 2``) is ``<= 120`` are dropped.  Returns the columns
    ``[dep, 'year', *columns]`` with incomplete rows removed; row order is
    preserved because the bootstrap draws positional indices.
    """
    sample = data.loc[data['dep'] == dep].copy()

    if g > 1:
        lag = 'lag_' + dep
        sample = sample.sort_values(by=['isocode', 'year', 'month'])
        sample[lag] = sample.groupby('isocode')[dep].shift(1)
        sample['lag_isocode'] = sample.groupby('isocode')['isocode'].shift(1)
        ongoing = (sample[lag] == 1) & (sample['lag_isocode'] == sample['isocode'])
        sample = sample.loc[~ongoing]

    if hard:
        diplomacy = sample[dp] if longer != 2 else sample['anyviolence_dp']
        # NOTE(review): rows with a missing diplomacy value are kept here
        # (NaN is not <= 120) -- preserved from the original; confirm intended.
        sample = sample.loc[~(diplomacy <= 120)]

    return sample[[dep, 'year', *columns]].dropna(how='any')


def _score_predictor(y, scores, rng, n_bootstraps):
    """Return full-sample ROC/PR curves plus a bootstrap 95% CI for the ROC AUC.

    *y* and *scores* are 1-D arrays of equal length.  *rng* is a
    ``np.random.RandomState`` shared across predictors, so the draws depend on
    the order in which predictors are scored.
    """
    n = len(scores)
    boot_auc = []
    for _ in range(n_bootstraps):
        # Identical draws to the deprecated rng.random_integers(0, n - 1, n).
        idx = rng.randint(0, n, n, dtype='l')
        y_boot = y[idx]
        if np.unique(y_boot).size < 2:
            # ROC AUC is undefined for a single-class resample; skip it rather
            # than letting a NaN end up in the sorted scores.
            continue
        fpr_b, tpr_b, _ = roc_curve(y_boot, scores[idx])
        boot_auc.append(auc(fpr_b, tpr_b))
    auc_lo, auc_hi = _bootstrap_interval(boot_auc)

    fpr, tpr, _ = roc_curve(y, scores)
    precision, recall, _ = precision_recall_curve(y, scores)
    return {
        'fpr': fpr,
        'tpr': tpr,
        'auc': auc(fpr, tpr),
        'auc_lo': auc_lo,
        'auc_hi': auc_hi,
        'precision': precision,
        'recall': recall,
        'ap': average_precision_score(y, scores),
    }


def _score_events(data, dpdf, conf_type, g, longer, preders, n_bootstraps, rng_seed):
    """Score every predictor for every outcome in *dpdf*, all (p=0) and hard (p=1) cases.

    Returns a dict keyed by ``(outcome index, predictor index, p)``.
    """
    results = {}
    for d, dup in enumerate(dpdf):
        dep = f'{conf_type}{dup}{g}'
        dp = dup + '_dp'
        for p in (0, 1):
            sample = _prepare_sample(data, dep, dp, g, longer, hard=(p == 1), columns=preders)
            y = sample[dep].to_numpy()
            x = sample[preders].to_numpy()

            print('observations pariwise p', p, ' ', len(y), ' ', conf_type, ' ', dup)

            # Re-seeded for each (outcome, case set) for reproducibility.
            rng = np.random.RandomState(rng_seed)
            for i in range(len(preders)):
                results[d, i, p] = _score_predictor(y, x[:, i], rng, n_bootstraps)
    return results


def _auc_legend(label, result, with_ci):
    """Legend text ``'<label> 0.85'`` or ``'<label> 0.85 [0.80,0.90]'``."""
    text = f'{label} {result["auc"]:0.2f}'
    if with_ci:
        text += f' [{result["auc_lo"]:0.2f},{result["auc_hi"]:0.2f}]'
    return text


def _roc_curves(results, d, p, labels, with_ci):
    """Return ``(fpr, tpr, legend text)`` for each predictor of panel ``(d, p)``."""
    return [(results[d, i, p]['fpr'], results[d, i, p]['tpr'],
             _auc_legend(label, results[d, i, p], with_ci))
            for i, label in enumerate(labels)]


def _draw_panel(curves, colors, lines, *, lw, legend_size, legend_loc,
                title, xlabel, ylabel, diagonal=False):
    """Plot ``(x, y, legend text)`` curves into the current axes and decorate them."""
    if diagonal:
        chance = np.linspace(0, 1)
        plt.plot(chance, chance, color='gray', lw=0.25, linestyle='--')
    for (x, y, text), color, style in zip(curves, colors, lines):
        plt.plot(x, y, label=text, color=color, linestyle=style, linewidth=lw)
    plt.ylabel(ylabel, fontsize=6)
    plt.xlabel(xlabel, fontsize=6)
    plt.legend(frameon=False, fancybox=True, framealpha=0.5, loc=legend_loc,
               fontsize=legend_size)
    plt.title(title, fontsize=6)


def _plot_precision(results, labels, colors, lines, event_names, n_events, paths):
    """One row of precision-recall panels (all cases), saved to every path in *paths*."""
    plt.figure()
    for d in range(n_events):
        plt.subplot(1, n_events, d + 1).set(adjustable='box', aspect='equal')
        curves = [(results[d, i, 0]['recall'], results[d, i, 0]['precision'],
                   f'{label} {results[d, i, 0]["ap"]:0.2f}')
                  for i, label in enumerate(labels)]
        _draw_panel(curves, colors, lines, lw=1, legend_size=4.5,
                    legend_loc='upper right', title=event_names[d],
                    xlabel='True Positive Rate',
                    ylabel='Precision' if d == 0 else '')
    for path in paths:
        plt.savefig(path, bbox_inches='tight')
    plt.close()


def _plot_roc_grid(results, labels, colors, lines, event_names, n_events,
                   lw, legend_size, with_ci, path):
    """ROC panels: top row all cases, bottom row hard cases, one column per outcome."""
    plt.figure()
    for p, case in enumerate(('all cases', 'hard cases')):
        for d in range(n_events):
            plt.subplot(2, n_events, p * n_events + d + 1)
            _draw_panel(_roc_curves(results, d, p, labels, with_ci), colors, lines,
                        lw=lw, legend_size=legend_size, legend_loc='lower right',
                        title=f'{event_names[d]} ({case})',
                        xlabel='False Positive Rate' if p == 1 else '',
                        ylabel='True Positive Rate' if d == 0 else '',
                        diagonal=True)
    plt.savefig(path, bbox_inches='tight')
    plt.close()


def _plot_roc_event(results, d, labels, colors, lines, path):
    """Side-by-side ROC panels (all / hard cases, with CIs) for outcome *d*."""
    plt.figure()
    for p, title in enumerate(('All cases', 'Hard cases')):
        plt.subplot(1, 2, p + 1).set(adjustable='box', aspect='equal')
        _draw_panel(_roc_curves(results, d, p, labels, True), colors, lines,
                    lw=1.5, legend_size=5.5, legend_loc='lower right', title=title,
                    xlabel='False Positive Rate',
                    ylabel='True Positive Rate' if p == 0 else '',
                    diagonal=True)
    plt.savefig(path, bbox_inches='tight')
    plt.close()


def plotting_rocs(t, tokens, cutter, kick_weird, where, model_type, longer, g, comp, dpdf,
                  *, n_bootstraps=1000, rng_seed=42):
    """Compute bootstrapped ROC / precision-recall metrics and save comparison figures.

    Reads the predictions CSV
    ``{comp}{model_type}{tokens}_{cutter}_{kick_weird}_{where}_topics{t}_{longer}_{g}.csv``.
    It must have a ``dep`` column naming the outcome of each row, a ``year``
    column, the outcome columns and the predictor columns.  For every predictor
    grouping (see ``_pairwise_specs``) and every outcome prefix (``'ons_'``, plus
    ``''`` when ``longer == 0``), PDFs are written to ``{comp}Figures_{where}/``:

    * precision-recall curves, all cases (``..._precision_..._two{g}.pdf`` and
      ``..._precision_..._{g}.pdf``, identical content);
    * a ROC grid of all / hard cases without CIs (``..._AUC_..._{g}.pdf``) and
      with 95% bootstrap CIs (``..._AUC_..._two{g}.pdf``);
    * one ROC figure with CIs per outcome in *dpdf*.

    "Hard cases" are rows whose ``<outcome>_dp`` column (``anyviolence_dp``
    when ``longer == 2``) is not ``<= 120``.

    Parameters
    ----------
    t, tokens, cutter, kick_weird, where
        Only used to build file names (formatted as text).  ``t`` appears as
        ``_topics{t}`` -- presumably the number of topics.
    model_type : str
        File-name component; a value ending in ``'single_'`` selects the
        classifier comparison (unless ``longer == 1``).
    longer : int
        Variant switch: 0 adds the ``''`` outcome prefix, 1 selects the five
        extended comparisons, 2 uses three outcome panels and ``anyviolence_dp``.
    g : int
        Suffix of the outcome column names; when ``g > 1`` rows following an
        observation with outcome 1 in the same ``isocode`` are dropped.
    comp : str
        Path prefix prepended verbatim (should end with a separator).
    dpdf : list of str
        Outcome base names; needs at least 3 entries when ``longer == 2``,
        otherwise at least 2.
    n_bootstraps : int, keyword-only
        Bootstrap resamples for the AUC confidence interval.
    rng_seed : int, keyword-only
        Seed of the ``RandomState`` re-created for each outcome / case set.
    """
    print(model_type)

    data = pd.read_csv(Path(
        f'{comp}{model_type}{tokens}_{cutter}_{kick_weird}_{where}'
        f'_topics{t}_{longer}_{g}.csv'))

    conf_types = ['ons_', ''] if longer == 0 else ['ons_']
    event_names, n_events = _event_layout(longer)

    for pairs, (labels, preders, colors, lines) in enumerate(
            _pairwise_specs(model_type, longer)):
        print(colors)
        print(labels)
        print(preders)

        for conf_type in conf_types:
            results = _score_events(data, dpdf, conf_type, g, longer, preders,
                                    n_bootstraps, rng_seed)

            head = f'{comp}Figures_{where}/{model_type}_{conf_type}{longer}'
            mid = f'{kick_weird}_{cutter}_{tokens}_pairwise{pairs}_'

            plt.rc('xtick', labelsize=4)
            plt.rc('ytick', labelsize=4)

            _plot_precision(results, labels, colors, lines, event_names, n_events,
                            [f'{head}_precision_{mid}{t}_two{g}.pdf',
                             f'{head}_precision_{mid}{t}_{g}.pdf'])
            _plot_roc_grid(results, labels, colors, lines, event_names, n_events,
                           lw=1, legend_size=4.5, with_ci=False,
                           path=f'{head}_AUC_{mid}{t}_{g}.pdf')
            _plot_roc_grid(results, labels, colors, lines, event_names, n_events,
                           lw=2, legend_size=5.5, with_ci=True,
                           path=f'{head}_AUC_{mid}{t}_two{g}.pdf')
            for d in range(n_events):
                _plot_roc_event(results, d, labels, colors, lines,
                                path=f'{head}_AUC_{mid}{dpdf[d]}{t}_{g}.pdf')


        
#plotting_rocs(30, '1_3','200', '0', 'superall', 'RandomForest', 2,1,'/Users/christopherrauh/Dropbox/JEEA_replication/', ['ged_best_ns', 'ged_best_os', 'ged_best_sb'])        







